Transformer model distillation


Transformer models pre-trained on large corpora, such as BERT, XLNet and XLM, have been shown to improve the accuracy of many NLP tasks. However, such models have two distinct disadvantages: (1) model size and (2) speed, since such large models are computationally heavy.

One possible approach to overcome these limitations is Knowledge Distillation (KD). In this approach, a large model is first trained on the data set and then used to teach a much smaller and more efficient network. This is often referred to as Student-Teacher training: a term measuring how far the student's outputs are from the teacher's is added to the student's loss function, helping the student network converge to a better solution.

Knowledge Distillation

One approach is similar to the method in Hinton 2015 [1]. The loss function is modified to include a measure of the divergence between the output distributions of the student and teacher networks, computed using either KL divergence or MSE between the logits of the two networks.

\(loss = w_s \cdot loss_{student} + w_d \cdot \mathrm{KL}\left(logits_{student} / T \,\|\, logits_{teacher} / T\right)\)

where \(T\) is the temperature used to soften the logits before applying softmax, \(loss_{student}\) is the original loss of the student network obtained during regular training, and \(w_s\) and \(w_d\) weight the student loss and the distillation loss, respectively.
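The loss above can be sketched in plain Python for a single example. This is a minimal illustration under stated assumptions, not NLP Architect's implementation: kd_loss and log_softmax are hypothetical names, and w_s and w_d play the roles of the loss_w and dist_w constructor parameters. The KL branch computes KL(p_teacher ‖ p_student) on the temperature-softened distributions, which is the direction PyTorch's KLDivLoss computes when the teacher distribution is passed as the target. The MSE branch compares the raw logits; leaving the temperature out there is an assumption.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Log of the temperature-softened softmax; a larger T flattens the distribution."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max before exponentiating, for numerical stability
    log_norm = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_norm for z in scaled]


def kd_loss(student_loss: float,
            student_logits: Sequence[float],
            teacher_logits: Sequence[float],
            temperature: float = 1.0,
            w_s: float = 1.0,
            w_d: float = 0.1,
            loss_function: str = "kl") -> float:
    """Hypothetical helper: w_s * student_loss + w_d * divergence(student, teacher)."""
    if loss_function == "kl":
        log_p_s = log_softmax(student_logits, temperature)
        log_p_t = log_softmax(teacher_logits, temperature)
        # KL(p_teacher || p_student) over the softened distributions
        divergence = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
    elif loss_function == "mse":
        # mean squared error directly on the raw logits (no temperature, by assumption)
        divergence = sum((s - t) ** 2 for s, t in zip(student_logits, teacher_logits)) / len(student_logits)
    else:
        raise ValueError(f"unknown loss_function: {loss_function!r}")
    return w_s * student_loss + w_d * divergence


# Identical logits: the distillation term vanishes and only w_s * student_loss remains.
print(kd_loss(2.0, [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # → 2.0
# A flatter student than teacher: the KL term adds a positive penalty.
print(kd_loss(2.0, [0.0, 0.0], [2.0, 0.0], temperature=2.0))
```

Raising the temperature flattens both distributions, so for the same pair of logits the KL term shrinks as T grows; the weights then decide how much that softened signal counts against the regular student loss.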


The TeacherStudentDistill class can be added to a model to support distillation. The student model's training loop must handle training with a TeacherStudentDistill object; see nlp_architect.procedures.token_tagging.do_kd_training for an example of how to train a neural tagger with distillation from a transformer model.
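The pattern a student's training loop follows can be sketched without PyTorch or NLP Architect. Everything here is hypothetical and heavily simplified: teacher_logits is a fixed linear scorer standing in for a trained transformer, the student is a two-class linear softmax model, gradients are derived by hand instead of by autograd, and train_student_with_kd is not the do_kd_training procedure. What it does show is the per-example order of operations: query the frozen teacher, run the student, combine the two losses with the weighted formula above, and update only the student.

```python
import math
import random
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Temperature-softened softmax (max-subtracted for numerical stability)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def teacher_logits(x: Sequence[float]) -> List[float]:
    """Hypothetical frozen teacher: a fixed linear scorer that is never updated."""
    return [2.0 * x[0] - x[1], -x[0] + 2.0 * x[1]]


def train_student_with_kd(data: List[Tuple[List[float], int]], epochs: int = 50, lr: float = 0.1,
                          temperature: float = 2.0, w_s: float = 1.0, w_d: float = 0.1):
    """Toy KD loop; returns the student's weights and its mean combined loss per epoch."""
    weights = [[0.0, 0.0], [0.0, 0.0]]  # student: weights[class][feature]
    history = []
    for _ in range(epochs):
        total_loss = 0.0
        for x, y in data:
            t_logits = teacher_logits(x)  # 1) teacher output, used only as a target
            s_logits = [sum(w * xi for w, xi in zip(row, x)) for row in weights]  # 2) student forward
            p_s = softmax(s_logits)  # T = 1 for the regular (hard-label) loss
            student_loss = -math.log(p_s[y])
            p_t_soft = softmax(t_logits, temperature)
            p_s_soft = softmax(s_logits, temperature)
            kl = sum(t * math.log(t / s) for t, s in zip(p_t_soft, p_s_soft))
            total_loss += w_s * student_loss + w_d * kl  # 3) combined loss, as in the formula
            # 4) hand-derived gradient of the combined loss w.r.t. each student logit
            grads = [w_s * (p_s[c] - (1.0 if c == y else 0.0))
                     + w_d * (p_s_soft[c] - p_t_soft[c]) / temperature
                     for c in range(len(weights))]
            # 5) SGD step on the student only; the teacher is untouched
            for c, g in enumerate(grads):
                for d, xd in enumerate(x):
                    weights[c][d] -= lr * g * xd
        history.append(total_loss / len(data))
    return weights, history


# Toy data: gold label 0 when x[0] >= x[1], which the hypothetical teacher also predicts.
rng = random.Random(0)
data = []
for _ in range(40):
    x = [rng.uniform(-1.0, 1.0), rng.uniform(-1.0, 1.0)]
    data.append((x, 0 if x[0] >= x[1] else 1))

weights, history = train_student_with_kd(data)
print(f"mean combined loss: epoch 1 = {history[0]:.3f}, epoch {len(history)} = {history[-1]:.3f}")
```

In a real PyTorch loop, steps 4 and 5 would be replaced by calling backward on the combined loss and stepping an optimizer that holds only the student's parameters, so the teacher stays fixed.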

class nlp_architect.nn.torch.distillation.TeacherStudentDistill(teacher_model: nlp_architect.models.TrainableModel, temperature: float = 1.0, dist_w: float = 0.1, loss_w: float = 1.0, loss_function='kl')

Teacher-Student knowledge distillation helper. Use this object when training a model with KD and a teacher model.

Parameters:
  • teacher_model (TrainableModel) – teacher model
  • temperature (float, optional) – KD temperature. Defaults to 1.0.
  • dist_w (float, optional) – distillation loss weight. Defaults to 0.1.
  • loss_w (float, optional) – student loss weight. Defaults to 1.0.
  • loss_function (str, optional) – loss function to use (kl for KLDivLoss, mse for MSELoss). Defaults to 'kl'.
static add_args(parser: argparse.ArgumentParser)

Add KD arguments to the parser.

Parameters: parser (argparse.ArgumentParser) – parser
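One plausible shape for such a hook is to register the KD options on an existing argparse parser and later map the parsed values onto the constructor's keyword arguments. The sketch below assumes that pattern: add_kd_args, kd_kwargs_from_args and the flag names (--kd_temp, --kd_dist_w, --kd_student_w, --kd_loss_fn) are hypothetical and may not match the flags NLP Architect actually registers, while the defaults mirror the documented constructor defaults.

```python
import argparse


def add_kd_args(parser: argparse.ArgumentParser) -> None:
    """Hypothetical stand-in for add_args: register KD options on an existing parser."""
    group = parser.add_argument_group("knowledge distillation")
    # Defaults mirror the documented TeacherStudentDistill constructor defaults.
    group.add_argument("--kd_temp", type=float, default=1.0, help="KD temperature")
    group.add_argument("--kd_dist_w", type=float, default=0.1, help="distillation loss weight")
    group.add_argument("--kd_student_w", type=float, default=1.0, help="student loss weight")
    group.add_argument("--kd_loss_fn", choices=["kl", "mse"], default="kl",
                       help="distillation loss: kl (KLDivLoss) or mse (MSELoss)")


def kd_kwargs_from_args(args: argparse.Namespace) -> dict:
    """Map the parsed (hypothetical) flags onto the constructor's keyword names."""
    return {
        "temperature": args.kd_temp,
        "dist_w": args.kd_dist_w,
        "loss_w": args.kd_student_w,
        "loss_function": args.kd_loss_fn,
    }


parser = argparse.ArgumentParser(description="train a student tagger with KD (sketch)")
add_kd_args(parser)
args = parser.parse_args(["--kd_temp", "2.0", "--kd_loss_fn", "mse"])
print(kd_kwargs_from_args(args))
# → {'temperature': 2.0, 'dist_w': 0.1, 'loss_w': 1.0, 'loss_function': 'mse'}
```

Under this assumed mapping, the resulting dictionary could be unpacked into the constructor alongside a trained teacher, e.g. TeacherStudentDistill(teacher, **kd_kwargs_from_args(args)). Keeping the options in a dedicated argument group lets a training script call the hook once and keep its own flags separate.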
distill_loss(loss, student_logits, teacher_logits)

Add KD loss to the student loss.

Parameters:
  • loss – student loss
  • student_logits – student model logits
  • teacher_logits – teacher model logits

Returns: KD loss


Get teacher logits

Parameters: inputs – input
Returns: teacher logits

Supported models


Useful for training taggers from Transformer models. NeuralTagger models that use LSTM- and CNN-based embedders have ~3M parameters (~30-100x smaller than BERT models) and are ~10x faster on average.


  1. Train a transformer tagger using TransformerTokenClassifier or the nlp_architect train transformer_token command.
  2. Train a neural tagger (NeuralTagger) with distillation, using a TeacherStudentDistill object configured with the trained transformer model as its teacher. This can be done using NeuralTagger's train loop or the nlp_architect train tagger_kd command.


More models supporting distillation will be added in future releases.

[1] Geoffrey Hinton, Oriol Vinyals, Jeff Dean. Distilling the Knowledge in a Neural Network. 2015.